{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tuning Supervisado con SFTTrainer\n",
    "\n",
    "Este notebook demuestra cómo realizar fine-tuning del modelo `HuggingFaceTB/SmolLM2-135M` usando el `SFTTrainer` de la librería `trl`. Las celdas del notebook se ejecutan y harán el fine-tuning del modelo. Puedes seleccionar tu nivel de dificultad probando diferentes conjuntos de datos.\n",
    "\n",
    "<div style='background-color: lightblue; padding: 10px; border-radius: 5px; margin-bottom: 20px; color:black'>\n",
    "    <h2 style='margin: 0;color:blue'>Ejercicio: Fine-Tuning de SmolLM2 con SFTTrainer</h2>\n",
    "    <p>Toma un conjunto de datos desde el repositorio de Hugging Face y realiza un fine-tuning de un modelo con él.</p>\n",
    "    <p><b>Niveles de dificultad</b></p>\n",
    "    <p>🐢 Usa el conjunto de datos `HuggingFaceTB/smoltalk`</p>\n",
    "    <p>🐕 Prueba el conjunto de datos `bigcode/the-stack-smol` y realiza un fine-tuning de un modelo de generación de código en un subconjunto específico `data/python`.</p>\n",
    "    <p>🦁 Selecciona un conjunto de datos que esté relacionado con un caso de uso del mundo real que te interese.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instalar los requisitos en Google Colab\n",
    "# !pip install transformers datasets trl huggingface_hub\n",
    "\n",
    "\n",
    "# Autenticación en Hugging Face\n",
    "from huggingface_hub import login\n",
    "\n",
    "login()\n",
    "\n",
    "# Para mayor comodidad, puedes crear una variable de entorno con tu token de Hugging Face como HF_TOKEN en tu archivo .env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importa las bibliotecas necesarias\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "from trl import SFTConfig, SFTTrainer, setup_chat_format\n",
    "import torch\n",
    "\n",
    "# Establece dinámicamente el dispositivo de procesamiento.\n",
    "device = (\n",
    "    \"cuda\"\n",
    "    if torch.cuda.is_available()\n",
    "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    ")\n",
    "\n",
    "# Carga el modelo y el tokenizador\n",
    "model_name = \"HuggingFaceTB/SmolLM2-135M\"\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    pretrained_model_name_or_path=model_name\n",
    ").to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)\n",
    "\n",
    "# Configura el formato de chat\n",
    "model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)\n",
    "\n",
    "# Establece nuestro nombre para guardar y/o subir el modelo finetuneado\n",
    "finetune_name = \"SmolLM2-FT-MyDataset\"\n",
    "finetune_tags = [\"smol-course\", \"module_1\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generar con el modelo base\n",
    "\n",
    "Aquí probaremos el modelo base, que no tiene una plantilla de chat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probemos el modelo base antes de entrenar\n",
    "prompt = \"Write a haiku about programming\"\n",
    "\n",
    "# Formatea con la plantilla\n",
    "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)\n",
    "\n",
    "# Genera la respuesta\n",
    "inputs = tokenizer(formatted_prompt, return_tensors=\"pt\").to(device)\n",
    "outputs = model.generate(**inputs, max_new_tokens=100)\n",
    "print(\"Before training:\")\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparación del Conjunto de Datos\n",
    "\n",
    "Cargaremos un conjunto de datos de ejemplo y lo formatearemos para el entrenamiento. El conjunto de datos debe estar estructurado con pares de entrada-salida, donde cada entrada es un prompt y la salida es la respuesta esperada del modelo.\n",
    "\n",
    "**TRL formateará los mensajes de entrada según las plantillas de chat del modelo.** Estos deben representarse como una lista de diccionarios con las claves: `role` y `content`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Carga un conjunto de datos de ejemplo\n",
    "from datasets import load_dataset\n",
    "\n",
    "# TODO: define tu conjunto de datos y configuración usando los parámetros de ruta y nombre\n",
    "ds = load_dataset(path=\"HuggingFaceTB/smoltalk\", name=\"everyday-conversations\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: 🦁 Si tu conjunto de datos no está en un formato que TRL pueda convertir a la plantilla de chat, necesitarás procesarlo. Consulta el [módulo](../chat_templates.md)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuración del SFTTrainer\n",
    "\n",
    "El `SFTTrainer` se configura con varios parámetros que controlan el proceso de entrenamiento. Estos incluyen el número de pasos de entrenamiento, el tamaño del lote, la tasa de aprendizaje y la estrategia de evaluación. Ajusta estos parámetros según tus requisitos específicos y recursos computacionales."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configura el SFTTrainer\n",
    "sft_config = SFTConfig(\n",
    "    output_dir=\"./sft_output\",  # Directorio de salida para los resultados del entrenamiento\n",
    "    max_steps=1000,  # Ajusta según el tamaño del conjunto de datos y la duración de entrenamiento deseada\n",
    "    per_device_train_batch_size=4,  # Ajusta según la capacidad de la memoria de tu GPU\n",
    "    learning_rate=5e-5,  # Punto de partida común para el fine-tuning\n",
    "    logging_steps=10,  # Frecuencia de registro de métricas de entrenamiento\n",
    "    save_steps=100,  # Frecuencia de guardado de puntos de control del modelo\n",
    "    evaluation_strategy=\"steps\",  # Evalua el modelo en intervalos regulares\n",
    "    eval_steps=50,  # Frecuencia de evaluación\n",
    "    use_mps_device=(\n",
    "        True if device == \"mps\" else False\n",
    "    ),  # Usa MPS para entrenamiento con precisión mixta\n",
    "    hub_model_id=finetune_name,  # Asigna un nombre único a tu modelo\n",
    ")\n",
    "\n",
    "# Inicializa el SFTTrainer\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=sft_config,\n",
    "    train_dataset=ds[\"train\"],\n",
    "    tokenizer=tokenizer,\n",
    "    eval_dataset=ds[\"test\"],\n",
    ")\n",
    "\n",
    "# TODO: 🦁 🐕 ajusta los parámetros del SFTTrainer a tu conjunto de datos elegido. Por ejemplo, si usas el conjunto de datos `bigcode/the-stack-smol`, necesitarás elegir la columna `content`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entrenamiento del Modelo\n",
    "\n",
    "Con el entrenador configurado, ahora podemos proceder a entrenar el modelo. El proceso de entrenamiento implicará iterar sobre el conjunto de datos, calcular la pérdida y actualizar los parámetros del modelo para minimizar esta pérdida."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entrena el modelo\n",
    "trainer.train()\n",
    "\n",
    "# Guarda el modelo\n",
    "trainer.save_model(f\"./{finetune_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub(tags=finetune_tags)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style='background-color: lightblue; padding: 10px; border-radius: 5px; margin-bottom: 20px; color:black'>\n",
    "    <h2 style='margin: 0;color:blue'>Ejercicio adicional: Generar con el modelo ajustado</h2>\n",
    "    <p>🐕 Utiliza el modelo ajustado para generar una respuesta, de la misma forma que con el ejemplo base.</p>\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prueba el modelo ajustado con el mismo mensaje\n",
    "\n",
    "# Probemos el modelo base antes del entrenamiento\n",
    "prompt = \"Escribe un haiku sobre programación\"\n",
    "\n",
    "# Formatea con la plantilla\n",
    "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)\n",
    "\n",
    "# Genera respuesta\n",
    "inputs = tokenizer(formatted_prompt, return_tensors=\"pt\").to(device)\n",
    "\n",
    "# TODO: usa el modelo ajustado para generar una respuesta, igual que en el ejemplo base."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 💐 ¡Has terminado!\n",
    "\n",
    "Este notebook proporcionó una guía paso a paso para realizar fine-tuning del modelo `HuggingFaceTB/SmolLM2-135M` utilizando el `SFTTrainer`. Siguiendo estos pasos, puedes adaptar el modelo para realizar tareas específicas de manera más efectiva. Si deseas continuar trabajando en este curso, aquí tienes algunas sugerencias que podrías intentar:\n",
    "\n",
    "- Intenta este notebook en un nivel de dificultad más alto.\n",
    "- Revisa el PR de un compañero.\n",
    "- Mejora el material del curso a través de un Issue o PR.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
